Batch Normalization from scratch


In [1]:
import mxnet as mx
import numpy as np
mx.random.seed(1)

Context


In [2]:
# Use the GPU when one is available; fall back to the CPU so the notebook
# still runs end-to-end on CPU-only machines.
try:
    ctx = mx.gpu()
    # mx.nd ops are lazy — force an allocation to verify the GPU actually works.
    mx.nd.zeros((1,), ctx=ctx).asnumpy()
except mx.base.MXNetError:
    ctx = mx.cpu()

Batch normalization

  • Typically inserted between a dense/convolution layer and its activation function (i.e., it normalizes the pre-activation values)

In [3]:
def pure_batch_norm(X, gamma, beta, eps=1e-5):
    """Batch-normalize `X` using statistics computed from the mini-batch itself.

    Parameters
    ----------
    X : mx.nd.NDArray
        Either a dense batch of shape (N, C) or a 2d-conv batch of shape
        (N, C, H, W).
    gamma : mx.nd.NDArray
        Per-feature/per-channel scale, shape (C,).
    beta : mx.nd.NDArray
        Per-feature/per-channel shift, shape (C,).
    eps : float
        Small constant added to the variance for numerical stability.

    Returns
    -------
    mx.nd.NDArray
        gamma * (X - mean) / sqrt(var + eps) + beta, same shape as `X`.

    Raises
    ------
    ValueError
        If `X` is neither 2-dimensional nor 4-dimensional.
    """
    if len(X.shape) not in (2, 4):
        raise ValueError('only supports dense or 2dconv')

    if len(X.shape) == 2:
        # Dense input: one statistic per feature, averaged over the batch axis.
        axes = (0,)
        param_shape = (1, X.shape[1])
    else:
        # 2d conv input: one statistic per channel, averaged over the batch
        # and both spatial axes.
        axes = (0, 2, 3)
        param_shape = (1, X.shape[1], 1, 1)

    # Mini-batch mean and (biased) variance, reshaped so they broadcast over X.
    mean = mx.nd.mean(X, axis=axes).reshape(param_shape)
    variance = mx.nd.mean((X - mean) ** 2, axis=axes).reshape(param_shape)
    # Normalize, then scale and shift with the learnable parameters.
    X_hat = (X - mean) / mx.nd.sqrt(variance + eps)
    return gamma.reshape(param_shape) * X_hat + beta.reshape(param_shape)

Example


In [4]:
# A small dense batch: 3 samples (rows) x 2 features (columns).
A = mx.nd.array([1, 2, 3, 6, 5, 7], ctx=ctx).reshape((3, 2))
A


Out[4]:
[[1. 2.]
 [3. 6.]
 [5. 7.]]
<NDArray 3x2 @gpu(0)>

In [5]:
# Normalize A with identity scale (gamma=1) and zero shift (beta=0):
# each feature column comes out with zero mean and (near-)unit variance.
pure_batch_norm(X=A,
                gamma=mx.nd.array([1,1], ctx=ctx),
                beta=mx.nd.array([0,0], ctx=ctx))


Out[5]:
[[-1.2247427  -1.3887286 ]
 [ 0.          0.46290955]
 [ 1.2247427   0.9258191 ]]
<NDArray 3x2 @gpu(0)>

In [6]:
# A conv-style batch: 2 samples x 2 channels x 2x2 spatial.
# NOTE: the original list had 17 values, but (2, 2, 2, 2) holds exactly 16;
# legacy MXNet reshape silently dropped the trailing value (see Out[6]).
# Listing exactly 16 values makes the cell correct and portable.
B = mx.nd.array([1,6,5,7,4,3,2,5,6,3,2,4,5,3,2,5], ctx=ctx).reshape((2, 2, 2, 2))
B


Out[6]:
[[[[1. 6.]
   [5. 7.]]

  [[4. 3.]
   [2. 5.]]]


 [[[6. 3.]
   [2. 4.]]

  [[5. 3.]
   [2. 5.]]]]
<NDArray 2x2x2x2 @gpu(0)>

In [7]:
# 1st sample, 1st channel (feature map)
B[0, 0, :, :]


Out[7]:
[[1. 6.]
 [5. 7.]]
<NDArray 2x2 @gpu(0)>

In [8]:
# 1st sample, 2nd channel (feature map)
B[0, 1, :, :]


Out[8]:
[[4. 3.]
 [2. 5.]]
<NDArray 2x2 @gpu(0)>

In [9]:
# Normalize B per channel: the mean/variance for each of the 2 channels are
# computed over the batch and both spatial dimensions (axes 0, 2, 3).
pure_batch_norm(X=B,
                gamma=mx.nd.array([1,1], ctx=ctx),
                beta=mx.nd.array([0,0], ctx=ctx))


Out[9]:
[[[[-1.637844    0.881916  ]
   [ 0.377964    1.385868  ]]

  [[ 0.30779248 -0.51298743]
   [-1.3337674   1.1285723 ]]]


 [[[ 0.881916   -0.62994   ]
   [-1.1338919  -0.12598799]]

  [[ 1.1285723  -0.51298743]
   [-1.3337674   1.1285723 ]]]]
<NDArray 2x2x2x2 @gpu(0)>

In [10]:
# Recompute the normalized tensor (same call as the previous cell) and keep a
# reference so individual channels can be sliced out below.
B_normalized = pure_batch_norm(X=B,
                               gamma=mx.nd.array([1,1], ctx=ctx),
                               beta=mx.nd.array([0,0], ctx=ctx))

In [11]:
# Normalized values of the 1st sample's 1st channel (compare with Out[7]).
B_normalized[0, 0, :, :]


Out[11]:
[[-1.637844  0.881916]
 [ 0.377964  1.385868]]
<NDArray 2x2 @gpu(0)>